'''
尝试下采样方法
'''


import os
import csv
import pandas as pd


# 读取数据
def read_data(data_path, encoding=None):
    """Read a CSV file and return every row, header included.

    Args:
        data_path: Path to the CSV file.
        encoding: Text encoding passed to ``open``; ``None`` (the default)
            uses the platform's locale encoding, as the original code did.

    Returns:
        A list of rows, each row a list of strings as produced by
        ``csv.reader``. The first row is whatever the file's first line is
        (callers treat it as the header).
    """
    # newline='' is required by the csv module so quoted fields containing
    # line breaks are parsed correctly; the ``with`` closes the handle, which
    # the original left open.
    with open(data_path, 'r', newline='', encoding=encoding) as f:
        return list(csv.reader(f))

# Split rows by class label and save the under-sampling subsets
def _split_evenly(rows, n_parts):
    """Split ``rows`` into ``n_parts`` consecutive chunks.

    Every chunk except the last holds ``len(rows) // n_parts`` rows; the last
    chunk takes whatever remains (so it may be larger, and when there are
    fewer rows than parts all leading chunks are empty).
    """
    size = len(rows) // n_parts
    chunks = [rows[i * size:(i + 1) * size] for i in range(n_parts - 1)]
    chunks.append(rows[(n_parts - 1) * size:])
    return chunks


def _write_rows(rows, header, file_path):
    """Write ``rows`` to ``file_path`` as CSV using ``header`` as column names."""
    pd.DataFrame(rows, columns=header).to_csv(file_path, index=False)


def divide_csv(data_list,
               save_path='D:/lung_cancer/data/data_augmentation/divide_csv/under_sampling_five/',
               label_col=5):
    """Group rows by class label (1-5) and save under-sampling subsets.

    Class 1 is cut into 8 CSV files and class 2 into 2 CSV files. In each
    case the last file takes the remainder. Classes 3, 4 and 5 are each
    saved whole as ``1.csv``. Files go to ``<save_path>/{one,two,three,
    four,five}/<n>.csv``. Before writing, the per-part sizes are printed.

    Args:
        data_list: Rows as returned by ``read_data``; ``data_list[0]`` is the
            header and is used as the column names of every output file.
        save_path: Output root directory (default is the original hard-coded
            location).
        label_col: Index of the column holding the integer class label.

    Raises:
        ValueError: if ``data_list`` is empty, or a row's label is not an
            integer in 1-5. Validation happens before any file is written,
            so no partial output is produced.
    """
    if not data_list:
        raise ValueError('data_list is empty; expected at least a header row')
    header = data_list[0]

    # (label, output sub-directory, number of parts to split into)
    plan = (
        (1, 'one', 8),
        (2, 'two', 2),
        (3, 'three', 1),
        (4, 'four', 1),
        (5, 'five', 1),
    )
    groups = {label: [] for label, _, _ in plan}

    for row_idx, row in enumerate(data_list[1:], start=1):
        label = int(row[label_col])
        if label not in groups:
            # The original printed "ERROR" and stopped reading, then still
            # wrote files from the truncated data; fail loudly instead.
            raise ValueError('row {}: unexpected class label {!r} (expected 1-5)'
                             .format(row_idx, row[label_col]))
        groups[label].append(row)

    # Report per-part sizes: floats for the split classes, counts otherwise.
    for label, _, n_parts in plan:
        count = len(groups[label])
        print(count / float(n_parts) if n_parts > 1 else count)

    for label, sub_dir, n_parts in plan:
        out_dir = os.path.join(save_path, sub_dir)
        os.makedirs(out_dir, exist_ok=True)
        chunks = _split_evenly(groups[label], n_parts)
        for part_no, chunk in enumerate(chunks, start=1):
            _write_rows(chunk, header, os.path.join(out_dir, '{}.csv'.format(part_no)))




if __name__ == '__main__':
    # Load the five-class training set, then split it into the
    # under-sampled per-class subsets.
    train_csv_path = 'D:/lung_cancer/data/data_augmentation/divide_csv/five/train.csv'
    divide_csv(read_data(train_csv_path))